探索跨文化艺术的条件 KNN

您所在的位置:网站首页 sql条件case when 探索跨文化艺术的条件 KNN

探索跨文化艺术的条件 KNN

2023-05-28 11:40| 来源: 网络整理| 查看: 265

通过快速、有条件、k-最近邻居探索跨文化和媒体的艺术 项目 05/23/2023

此笔记本用作通过 k-nearest-neighbors 查找匹配的指南。 我们设置了代码,允许从纽约大都会艺术博物馆和阿姆斯特丹的国家博物馆收集涉及文化和艺术媒介的查询。

先决条件 将笔记本附加到湖屋。 在左侧,选择“ 添加” 以添加现有 lakehouse 或创建 lakehouse。 BallTree 概述

kNN 模型背后的结构是 BallTree,它是递归二元树,其中每个节点 (或“ball”) 包含要查询的数据点的分区。 生成 BallTree 涉及将数据点分配给“球”,该“球”的中心最接近 (相对于特定指定特征) ,从而形成一个结构,该结构允许二叉树状遍历,并有助于在 BallTree 叶中查找 k 最近的邻居。

设置

导入必要的 Python 库并准备数据集。

from synapse.ml.core.platform import * if running_on_binder(): from IPython import get_ipython from pyspark.sql.types import BooleanType from pyspark.sql.types import * from pyspark.ml.feature import Normalizer from pyspark.sql.functions import lit, array, array_contains, udf, col, struct from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel from PIL import Image from io import BytesIO import requests import numpy as np import matplotlib.pyplot as plt from pyspark.sql import SparkSession # Bootstrap Spark Session spark = SparkSession.builder.getOrCreate()

我们的数据集来自一个表,其中包含来自大都会和 Rijks 博物馆的艺术品信息。 架构如下所示:

id:艺术品的唯一标识符 示例 Met ID: 388395 示例 Rijks ID: SK-A-2344 标题:艺术作品标题,写在博物馆的数据库 艺术家:艺术作品艺术家,如博物馆的数据库所写 Thumbnail_Url:艺术作品 JPEG 缩略图的位置 Image_Url 大都会/Rijks 网站上托管的艺术作品图像的位置 文化:艺术作品所属的文化类别 示例文化类别: 拉丁美洲、 埃及等。 分类:艺术作品所属的介质类别 示例介质类别: 木制品、 绘画等。 Museum_Page:大都会/Rijks 网站上的艺术作品链接 Norm_Features:嵌入艺术作品图像 博物馆:指定作品源自哪个博物馆 # loads the dataset and the two trained CKNN models for querying by medium and culture df = spark.read.parquet( "wasbs://[email protected]/met_and_rijks.parquet" ) display(df.drop("Norm_Features")) 定义要查询的类别

我们将使用两个 kNN 模型:一个用于区域性模型,一个用于中等模型。

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork", # "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"] mediums = ["paintings", "glass", "ceramics"] # cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)', # 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek', # 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian', # 'spanish', 'swiss', 'various'] cultures = ["japanese", "american", "african (general)"] # Uncomment the above for more robust and large scale searches! classes = cultures + mediums medium_set = set(mediums) culture_set = set(cultures) selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"} small_df = df.where( udf( lambda medium, culture, id_val: (medium in medium_set) or (culture in culture_set) or (id_val in selected_ids), BooleanType(), )("Classification", "Culture", "id") ) small_df.count() 定义和拟合 ConditionalKNN 模型

我们为介质列和区域性列创建 ConditionalKNN 模型;每个模型采用输出列、特征列 (特征向量) 、输出列) 下的值列 (单元格值,标签列 (相应 KNN 根据) 条件的质量。

medium_cknn = ( ConditionalKNN() .setOutputCol("Matches") .setFeaturesCol("Norm_Features") .setValuesCol("Thumbnail_Url") .setLabelCol("Classification") .fit(small_df) ) culture_cknn = ( ConditionalKNN() .setOutputCol("Matches") .setFeaturesCol("Norm_Features") .setValuesCol("Thumbnail_Url") .setLabelCol("Culture") .fit(small_df) ) 定义匹配和可视化方法

初始数据集和类别设置后,我们准备用于查询和可视化条件 kNN 结果的方法。

addMatches() 创建一个数据帧,其中包含每个类别的少量匹配项。

def add_matches(classes, cknn, df): results = df for label in classes: results = cknn.transform( results.withColumn("conditioner", array(lit(label))) ).withColumnRenamed("Matches", "Matches_{}".format(label)) return results

plot_urls() 调用 plot_img 以将每个类别的顶级匹配项可视化到网格中。

def plot_img(axis, url, title): try: response = requests.get(url) img = Image.open(BytesIO(response.content)).convert("RGB") axis.imshow(img, aspect="equal") except: pass if title is not None: axis.set_title(title, fontsize=4) axis.axis("off") def plot_urls(url_arr, titles, filename): nx, ny = url_arr.shape plt.figure(figsize=(nx * 5, ny * 5), dpi=1600) fig, axes = plt.subplots(ny, nx) # reshape required in the case of 1 image query if len(axes.shape) == 1: axes = axes.reshape(1, -1) for i in range(nx): for j in range(ny): if j == 0: plot_img(axes[j, i], url_arr[i, j], titles[i]) else: plot_img(axes[j, i], url_arr[i, j], None) plt.savefig(filename, dpi=1600) # saves the results as a PNG display(plt.show()) 汇总

我们定义 test_all() 用于获取数据、CKNN 模型、要查询的艺术 ID 值以及用于保存输出可视化效果的文件路径。 之前已训练和加载介质和区域性模型。

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png def test_all(data, cknn_medium, cknn_culture, test_ids, root): is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType()) test_df = data.where(is_nice_obj("id")) results_df_medium = add_matches(mediums, cknn_medium, test_df) results_df_culture = add_matches(cultures, cknn_culture, results_df_medium) results = results_df_culture.collect() original_urls = [row["Thumbnail_Url"] for row in results] culture_urls = [ [row["Matches_{}".format(label)][0]["value"] for row in results] for label in cultures ] culture_url_arr = np.array([original_urls] + culture_urls)[:, :] plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png") medium_urls = [ [row["Matches_{}".format(label)][0]["value"] for row in results] for label in mediums ] medium_url_arr = np.array([original_urls] + medium_urls)[:, :] plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png") return results_df_culture 演示

以下单元格在给定所需的图像 ID 和文件名的情况下执行批处理查询,以保存可视化效果。

# sample query result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".") 后续步骤 如何将 ONNX 与 SynapseML 配合使用 - 深度学习 如何使用内核 SHAP 解释表格分类模型 如何使用 SynapseML 进行多变量异常检测


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3